import ekf_testv6 as ekf
import numpy as np
# import pykalmani
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import scipy.stats as ss
import matplotlib.transforms as mtransforms
#plt.rcParams["font.family"] = "Arial"


# Curve colours shared by both figures: CESM2 LENS curves and the
# EBM-KF-ta curves.  hatch_below-style helpers may compare against LENSpcolor.
LENSpcolor, KFtacolor = 'coral', 'darkgreen'

def cdf_subplot(ax,loc,lims,xs,ps,pcolor):
    """Draw a threshold line on ``ax`` and overlay a small probability inset on it.

    Parameters
    ----------
    ax : parent axes whose data limits are described by ``lims``.
    loc : ``(threshold, thickness, start, width)``.  The inset is vertically
        centred on ``threshold`` with height ``thickness`` (both in the
        parent's y data units).  Horizontally it starts at ``start`` and is
        ``width`` x-units wide.
    lims : ``(ymin, ymax, xmin, xmax)`` of the parent axes in data units.
    xs, ps : x values and probabilities plotted inside the inset.
    pcolor : colour of the inset curve.

    Returns
    -------
    The newly created inset axes.
    """
    threshold, thickness, start, width = loc[0], loc[1], loc[2], loc[3]
    ymin, ymax, xmin, xmax = lims[0], lims[1], lims[2], lims[3]
    x_extent = xmax - xmin
    y_extent = ymax - ymin

    # Threshold marker from the parent's left edge to x = 2050.
    ax.plot([xmin, 2050], [threshold, threshold], '-', color='sienna')

    # bbox_to_anchor expects (x0, y0, width, height) as fractions of the parent axes.
    anchor_box = (
        (start - xmin) / x_extent,
        (threshold - thickness / 2.0 - ymin) / y_extent,
        width / x_extent,
        thickness / y_extent,
    )
    inset = inset_axes(ax, width="100%", height="100%",
                       bbox_to_anchor=anchor_box,
                       bbox_transform=ax.transAxes, borderpad=0)

    # Probability axis on the right-hand side, fixed to [0, 1].
    inset.yaxis.set_label_position("right")
    inset.yaxis.tick_right()
    inset.set_ylim(0, 1)
    inset.set_yticks([0, .5, 1])
    inset.set_xticks(np.arange(1850, 2050, 25))
    inset.tick_params(axis='x', which='major', pad=1)
    inset.set_ylabel('probability')
    inset.set_xlim(start, start + width)
    # Semi-transparent background so the parent plot shows through.
    inset.patch.set_alpha(0.6)
    inset.plot(xs, ps, '-', color=pcolor, linewidth=2)
    return inset
# Panel letters 'a)'..'d)' (this list is re-assigned near the end of the script).
axlabels = [letter + ')' for letter in 'abcd']

def draw_lines(ax,color2, q):
    """Add probability reference lines and axis labels to a threshold panel.

    Draws black horizontal lines from 1900 to 2100:
    * at probability 0.5, dash-dot;
    * at 0.159 and 0.841, dotted (approximately the standard normal CDF at
      -1 and +1 standard deviations).
    The y-range is fixed to [0, 1].

    Parameters
    ----------
    ax : axes to decorate.
    color2 : unused (the highlight lines that used it are disabled); kept so
        existing calls keep working.
    q : threshold value substituted (via ``%s``) into the y-axis label.
    """
    ax.plot([1900,2100],[.5,.5],'k-.')
    ax.plot([1900,2100],[.159,.159],'k:',zorder=-1)
    ax.plot([1900,2100],[.841,.841],'k:')
    ax.set_ylim(0,1)
    # Raw string: \mathbb and \degree are LaTeX commands, not Python escapes.
    # A plain literal yields the same text but emits DeprecationWarning
    # (SyntaxWarning on Python 3.12+).
    ax.set_ylabel(r'$\mathbb{P}( (?)_t ≥ %s \degree C)$' %q)
    ax.yaxis.set_label_coords(-.18, .5)
    ax.set_xlabel('Year')

# 750 evenly spaced years from 1950 to 2025 inclusive.  This global is
# reassigned further down before it is read.
fineyears = np.linspace(start=1950, stop=2025, num=750)
def hatch_below(ax3,thisdates,pvals,color1,shiftup=0):
    """Bracket the 15.9 %-84.1 % band of a probability curve along the x-axis.

    ``pvals`` (sampled at ``thisdates``) is linearly interpolated onto a grid
    of 1500 points spanning 1950-2100.  At height ``shiftup`` (data units):

    * a '(' is drawn at the first grid year where the probability exceeds
      0.159 -- only if it ever does;
    * if the curve ends at or above 0.841 after having been below it, a ')'
      is drawn where that final run begins.  When the '(' was also drawn, a
      thick bar joins the two brackets.

    As a diagnostic, the grid years at which the curve alternately rises
    above / falls below 0.5 are printed.  The printed list always starts
    with 1950.

    Parameters
    ----------
    ax3 : axes to draw on.
    thisdates, pvals : x values (years) and probabilities of the curve.
    color1 : colour of the brackets and bar.
    shiftup : vertical position of the glyphs, used to stack several curves.
    """
    # (The hatch direction that used to be chosen here only fed a disabled
    # fill_between, so it is no longer computed.)
    grid = np.linspace(1950, 2100, 1500)
    probs = np.interp(grid, thisdates, pvals)

    # Opening bracket.  np.argmax returns 0 when nothing is True, so check
    # explicitly; otherwise a curve that never exceeds 0.159 gets a spurious
    # '(' at 1950.
    above_lo = probs > 0.159
    sind = np.argmax(above_lo)
    has_start = bool(above_lo[sind])
    if has_start:
        ax3.text(grid[sind]-0.1, shiftup, '(', color=color1, ha='center', size=16)

    # Closing bracket: index just after the last sample below 0.841.  It
    # equals len(grid) when the curve ends below 0.841 or never dips below it.
    eind = len(grid) - np.argmax(np.flip(probs) < 0.841)
    if eind != len(grid):
        ax3.text(grid[eind], shiftup, ')', color=color1, ha='center', size=16)
        if has_start:
            ax3.plot([grid[sind], grid[eind]], [shiftup+0.02, shiftup+0.02], color=color1, lw=3)

    # Diagnostic: alternating up-/down-crossings of 0.5.  Each step searches
    # the remaining tail with argmax and stops when no crossing is left.
    xinds = [0]
    xind = np.argmax(probs > 0.5)
    direc = 1
    while xind > 0:
        xind = xind + xinds[-1]
        xinds.append(xind)
        direc = -direc
        if direc == 1:
            xind = np.argmax(probs[xind:] > 0.5)
        else:
            xind = np.argmax(probs[xind:] < 0.5)

    for x in xinds:
        print(grid[x])

# Aliases to EBM-KF outputs from ekf_testv6.  Used below as:
# xh1s with stdP (state band), xh0s with stdS (forecast band).
xh1s=ekf.xh1s
xh0s=ekf.xh0s
stdS=ekf.stdS
stdP=ekf.stdP

# Figure 1: 2x2 grid of (ax1, ax12) on top and (ax4, ax2) below.
fig, ((ax1,ax12),(ax4,ax2))= plt.subplots(2, 2, figsize=(8,7.5), gridspec_kw={'wspace':0.45,'hspace':0.35 })
bigfonts=10
bigfontb=12

plt.sca(ax4)
ax4.set_title('Temperature Forecast Probabilities \n of Threshold Crossings with EBM-KF-uf',fontsize=bigfontb)
# Start the prior series at the first posterior value.  NOTE: xh0s aliases
# ekf.xh0s, so (for a NumPy array) this also modifies ekf's array in place.
xh0s[0]=xh1s[0]
# Raw strings: the LaTeX backslashes (\^, \pm, \sqrt, \hat) are not valid
# Python escapes and warn in ordinary literals.  The text is unchanged.
ax4.plot(ekf.dates,xh0s,'-.',label=r'$\^{T }_{t|t-1}$', color=ekf.colorekf, linewidth=0.5)

ax4.fill_between(ekf.dates, xh0s-2*stdS, xh0s+2*stdS,label=r"$\pm 2\sqrt{\hat{s}^T_t}$", color=ekf.coloruncert)
lims=[286.1,288.3,1950,2025] #ymin ymax xmin xmax

def draw_threshold_lines(ax):
    """Overlay preindustrial-relative reference levels on a temperature panel.

    Sienna horizontal lines span ``lims[2]``..``lims[3]`` at the
    preindustrial mean ``ekf.pindavg`` and at +0.5 and +1.5 above it.  Labels
    for preindustrial, +0.5, +1.0 and +1.5 °C sit 0.15 below each level at
    x = 1952.  No +1.0 line is drawn here.  Inward ticks are enabled on the
    bottom, top and left.
    """
    base = ekf.pindavg

    def hline(offset):
        # Fresh x list per call, as the original did.
        ax.plot([lims[2], lims[3]], [base + offset, base + offset], '-', color='sienna')

    def label(offset, caption):
        ax.text(1952, base - 0.15 + offset, caption, color='sienna', fontsize=bigfontb)

    # Same drawing order as before (affects artist stacking).
    hline(0)
    label(0, 'preindustrial')
    hline(.5)
    label(0.5, '+0.5°C')
    label(1, '+1.0°C')
    label(1.5, '+1.5°C')
    hline(1.5)
    ax.tick_params(direction='in', bottom=True, top=True, left=True, right=False)

# Bottom-left panel (ax4): shared styling, threshold guides and axis ranges.
ekf.plot_boilerplate(ax4)
draw_threshold_lines(ax4)

ax4.set_xlim([lims[2],lims[3]])
ax4.set_ylim([lims[0],lims[1]])
ax4.set_yticks(np.arange(286.2,288.3,0.4))
# Off-screen (x = lims[3]+100) dummy line that only supplies a legend entry
# for the inset probability curve.  Raw string: the LaTeX backslashes are not
# Python escapes.
ax4.plot([lims[3]+100,lims[3]+100], [ekf.pindavg, ekf.pindavg], '-',color=ekf.colorekf,linewidth=2,label=r'$\mathbb{P}(Y_{t+1 | t} ≥ 1\degree C)$')
handles, labels = ax4.get_legend_handles_labels()
# Reorder legend entries by index; this depends on which labelled artists
# exist on ax4 at this point.
order = [0,2,1]
ax4.legend([handles[idx] for idx in order],[labels[idx] for idx in order],prop={'size': bigfonts-1}, loc='lower right',framealpha=1)

# Absolute temperature thresholds: the preindustrial mean plus 0.5 ... 2.5 °C.
thres1=ekf.pindavg+0.5
thres2=ekf.pindavg+1.0
thres3=ekf.pindavg+1.5
thres4=ekf.pindavg+2.0
thres5=ekf.pindavg+2.5
# tNs / tNl: start year and length (years) of the x-window used later for the
# threshold-N panels (e.g. set_xlim(tNs, tNs+tNl)).
t1s=1975
t1l=30
t2s=1998
t2l=30
t3s=2015
t3l=30
t4s=2030
t4l=30
t5s=2045
t5l=30
# Height (in the parent axes' y data units) of the cdf_subplot insets.
boxsize=.65

# Exceedance probabilities P(X > thresN) for a normal X, i.e. 1 - CDF:
#   thresN_ps  : mean xh0s, sd stdS  (temperature-forecast panels)
#   thresN_cps : mean xh1s, sd stdP  (climate-state panels)
thres1_ps=np.zeros((ekf.n_iters,1))
thres2_ps=np.zeros((ekf.n_iters,1))
thres3_ps=np.zeros((ekf.n_iters,1))
thres1_cps=np.zeros((ekf.n_iters,1))
thres2_cps=np.zeros((ekf.n_iters,1))
thres3_cps=np.zeros((ekf.n_iters,1))
for i in range(ekf.n_iters):
    thres1_ps[i]=1-ss.norm.cdf(thres1,xh0s[i],stdS[i])
    thres2_ps[i]=1-ss.norm.cdf(thres2,xh0s[i],stdS[i])
    thres3_ps[i]=1-ss.norm.cdf(thres3,xh0s[i],stdS[i])
    thres1_cps[i]=1-ss.norm.cdf(thres1,xh1s[i],stdP[i])
    thres2_cps[i]=1-ss.norm.cdf(thres2,xh1s[i],stdP[i])
    thres3_cps[i]=1-ss.norm.cdf(thres3,xh1s[i],stdP[i])
#print(thres1_ps)
#print(thres2_ps)
#cdf_subplot(ax,[thres1, boxsize, t1s,t1l],lims,ekf.dates,thres1_ps)
# Inset on ax4: forecast probability of exceeding thres2 (+1.0 °C), from t2s to 2025.
cdf_subplot(ax4,[thres2, boxsize, t2s,2025-t2s],lims,ekf.dates,thres2_ps,ekf.colorekf)

# Bottom-right panel (ax2): CESM2 LENS ensemble temperatures.
import summary_of_sims as sum_sims
sum_sims.plot_many_sims(ax2)
ekf.plot_boilerplate(ax2)
ax2.set_ylim([lims[0],lims[1]])
ax2.set_xlim([lims[2],lims[3]])
ax2.set_yticks(np.arange(286.2,288.3,0.4))
ax2.set_title('Temperature Forecast Probabilities \nof Threshold Crossings with CESM2 LENS',fontsize=bigfontb)
draw_threshold_lines(ax2)
# Off-screen dummy line that supplies the legend entry for the inset curve.
# Raw strings keep the LaTeX backslashes without invalid-escape warnings.
ax2.plot([lims[3]+100,lims[3]+100], [ekf.pindavg, ekf.pindavg], '-',color=LENSpcolor,linewidth=2,label=r"$\mathbb{P} ((Y_t)_j ≥ 1\degree C )$")
ax2.set_ylabel('Temperature (K)')
handles, labels = ax2.get_legend_handles_labels()
order = [1,0,2]
# The first two entries get explicit labels; the third keeps its own.
ax2.legend([handles[idx] for idx in order],[r'$\overline{(Y_t)_j}$','$(Y_t)_j$',labels[2]],prop={'size': bigfonts-1},loc='lower right',framealpha=1)

#compute centred N-year moving averages of every LENS simulation
N = 21
cN=int(np.ceil(N/2));fN=int(np.floor(N/2))
allsimsavgd=np.zeros(np.shape(sum_sims.twTR));allsimsavgd[:]=np.nan
for s in range(sum_sims.sms):
    tempssim = np.asarray(sum_sims.twTR[s,:], dtype=float)
    # Window sums come from cumulative sums, as before, but NaNs are zeroed
    # first.  A raw running sum turns NaN at the first missing year and stays
    # NaN, which wiped out every later window.  A parallel cumulative NaN
    # count instead marks as NaN only the windows that contain a missing year.
    missing = np.isnan(tempssim)
    csum = np.concatenate(([0.0], np.cumsum(np.where(missing, 0.0, tempssim))))
    cmiss = np.concatenate(([0], np.cumsum(missing)))
    moving_aves = np.full(len(tempssim), np.nan)
    if len(tempssim) >= N:
        window_sums = csum[N:] - csum[:-N]      # sum of tempssim[k:k+N]
        window_miss = cmiss[N:] - cmiss[:-N]    # NaNs inside that window
        # The window starting at k is stored at index k + (cN - 1), i.e.
        # centred for odd N.  Ends without a full window stay NaN.
        moving_aves[cN-1:cN-1+len(window_sums)] = np.where(window_miss > 0, np.nan, window_sums / N)
    allsimsavgd[s,:]=moving_aves

ax12.set_title('Climate State Probabilities \n of Threshold Crossings with CESM2 LENS', fontsize=bigfontb)
# Running means of each LENS member, trimmed by ekf.fN / ekf.cN samples at the
# ends.  Only the first member carries a legend label.  Raw strings keep the
# LaTeX backslashes without invalid-escape warnings.
ax12.plot(sum_sims.dates[ekf.fN:-ekf.cN+1],allsimsavgd[0,ekf.fN:-ekf.cN+1],'-',label=r'$( \overline{_{21}Y_t} )_j$', color=ekf.colorstate, linewidth=0.5,zorder=1)
for r in range(1,sum_sims.sms):
    ax12.plot(sum_sims.dates[ekf.fN:-ekf.cN+1],allsimsavgd[r,ekf.fN:-ekf.cN+1],'-', color=sum_sims.colorgen(r,(70,30,30),(104,195,121)), linewidth=0.5,zorder=1)
# Ensemble mean of the running means (NaNs ignored), drawn on top.
allsimsmean=np.nanmean(allsimsavgd,axis=0)
ax12.plot(sum_sims.dates[ekf.fN:-ekf.cN+1],allsimsmean[ekf.fN:-ekf.cN+1],'-',label=r'$\overline{( \overline{_{21}Y_t} )_j}$', color=ekf.colorekf,zorder=3,linewidth=0.5)


ekf.plot_boilerplate(ax12)
ax12.set_xlim([lims[2],lims[3]])
ax12.set_ylim([lims[0],lims[1]])
ax12.set_yticks(np.arange(286.2,288.3,0.4))
draw_threshold_lines(ax12)
# Off-screen dummy line that supplies the legend entry for the inset curve.
ax12.plot([lims[3]+100,lims[3]+100], [ekf.pindavg, ekf.pindavg], '-',color=LENSpcolor,linewidth=2,label=r"$\mathbb{P} ( (\overline{_{21}Y_t})_j ≥ 1\degree C )$")
handles, labels = ax12.get_legend_handles_labels()
order = [1,0,2]
ax12.legend([handles[idx] for idx in order],[labels[idx] for idx in order],prop={'size': bigfonts-1}, loc='lower right',framealpha=1)




# Per-year exceedance-fraction accumulators of shape (smyrs, 1):
# *_ss for raw LENS member values, *_css for their running means.
# ssn is allocated but not read later in this script.
(thres1_ss, thres2_ss, thres1_css, thres2_css,
 thres3_ss, thres4_ss, thres5_ss,
 thres3_css, thres4_css, thres5_css,
 ssn) = (np.zeros((sum_sims.smyrs, 1)) for _ in range(11))

#compute moving avg kalman

# Running mean of the measurements as provided by ekf_testv6.
moving_aves30=ekf.moving_aves;
# Exactly 0.1-year spacing from 1950 to 2024.9 (endpoint excluded).
fineyears=np.linspace(1950, 2025, 750,endpoint=False) #10 per year
moving_aves30_interp = np.interp(fineyears,ekf.dates,moving_aves30)
# First grid index whose interpolated running mean is NaN (0 if none);
# not referenced later in this script.
switch_yr_idx = np.argmax(np.isnan(moving_aves30_interp))
# First grid index where the running mean exceeds thres1 (0 if it never does).
x1crossavgyr=np.argmax(moving_aves30_interp>thres1)

# Extend the measurements by `moreyrs` years along the least-squares trend of
# the last 30 values, in three variants: the trend line itself, and the trend
# shifted by the largest / smallest value of (fit - observation) over those 30.
from scipy import stats
slope_30, intercept_30, r_30, p_30, std_err_30 = stats.linregress(ekf.dates[-30:],ekf.temps[-30:])
err_30 = slope_30 * ekf.dates[-30:] + intercept_30 - ekf.temps[-30:]
max_err30=np.max(err_30)
min_err30=np.min(err_30)
moreyrs=20
# NOTE(review): assumes the n_iters measurements are annual starting in 1850,
# so the extension begins at 1850+n_iters -- confirm against ekf.dates.
yearsfwd=np.linspace(ekf.n_iters+1850, ekf.n_iters+1850+moreyrs, moreyrs,endpoint=False) #1 per year
add_yrs= slope_30 * yearsfwd + intercept_30
temps_projj_max = np.concatenate((ekf.temps,add_yrs+max_err30))
temps_projj_min = np.concatenate((ekf.temps,add_yrs+min_err30))
temps_projj = np.concatenate((ekf.temps,add_yrs))
# Centred running means of the three extended series over the window
# [i-ekf.fN, i+ekf.cN).  Entries without a full window stay NaN.
moving_aves_projj_max=np.empty(ekf.n_iters+moreyrs)
moving_aves_projj_max[:] = np.nan
moving_aves_projj_min=np.copy(moving_aves_projj_max)
moving_aves_projj_c=np.copy(moving_aves_projj_max)
for i in range((ekf.fN),(len(temps_projj_max)-ekf.cN+1)):
    lasta=i+ekf.cN;firsta=i-ekf.fN
    moving_aves_projj_max[i] = np.mean(temps_projj_max[firsta:lasta]);
    moving_aves_projj_min[i] = np.mean(temps_projj_min[firsta:lasta]);
    moving_aves_projj_c[i] = np.mean(temps_projj[firsta:lasta]);

# Interpolate the two bounding projections onto the 0.1-year grid.  Find the
# first grid index at which each exceeds thres2 (0 if never).
moving_aves30_interp_max = np.interp(fineyears,np.arange(1850,ekf.n_iters+moreyrs+1850),moving_aves_projj_max)
moving_aves30_interp_min = np.interp(fineyears,np.arange(1850,ekf.n_iters+moreyrs+1850),moving_aves_projj_min)
x2crossavgyr_max=np.argmax(moving_aves30_interp_max>thres2)
x2crossavgyr_min=np.argmax(moving_aves30_interp_min>thres2)
##plt.figure()
##plt.plot(fineyears,moving_aves30_interp_max)
##plt.plot(fineyears,moving_aves30_interp_min)
##plt.plot(np.arange(1850,ekf.n_iters+moreyrs+1850),moving_aves_projj_c)
##plt.plot(fineyears,moving_aves30_interp)
##plt.plot(yearsfwd,add_yrs)
##
###print(xcrossavgyr)
##plt.figure()
##plt.plot(ekf.dates,moving_aves30)
##plt.plot([1900,2000],[thres1,thres1])

# Per-year fraction of LENS members above each threshold, for the raw
# simulations (thresN_ss) and for their running means (thresN_css).
# Only entries > 0 count as valid members; NaN entries compare False and are
# excluded, exactly as in the previous element-by-element counting.
for j in range(sum_sims.smyrs):
    sims_col = sum_sims.twTR[:,j]
    avgs_col = allsimsavgd[:,j]
    with np.errstate(invalid='ignore'):  # NaN comparisons are expected here
        num_sims_now = np.count_nonzero(sims_col > 0)
        # A year with no valid members now yields NaN fractions instead of
        # raising ZeroDivisionError, matching the running-mean branch below.
        if (num_sims_now==0):
            num_sims_now=np.nan
        thres1_ss[j]=np.count_nonzero(sims_col > thres1)/float(num_sims_now)
        thres2_ss[j]=np.count_nonzero(sims_col > thres2)/float(num_sims_now)
        thres3_ss[j]=np.count_nonzero(sims_col > thres3)/float(num_sims_now)
        thres4_ss[j]=np.count_nonzero(sims_col > thres4)/float(num_sims_now)
        thres5_ss[j]=np.count_nonzero(sims_col > thres5)/float(num_sims_now)
        num_avs_now=np.count_nonzero(avgs_col > 0)
        if (num_avs_now==0):
            num_avs_now=np.nan
        thres1_css[j]=np.count_nonzero(avgs_col > thres1)/float(num_avs_now)
        thres2_css[j]=np.count_nonzero(avgs_col > thres2)/float(num_avs_now)
        thres3_css[j]=np.count_nonzero(avgs_col > thres3)/float(num_avs_now)
        thres4_css[j]=np.count_nonzero(avgs_col > thres4)/float(num_avs_now)
        thres5_css[j]=np.count_nonzero(avgs_col > thres5)/float(num_avs_now)
# Insets: LENS probability of exceeding thres2 (+1.0 °C) from t2s to 2025,
# for the raw members (ax2) and for their running means (ax12).
cdf_subplot(ax2,[thres2, boxsize, t2s,2025-t2s],lims,sum_sims.dates,thres2_ss,LENSpcolor)



cdf_subplot(ax12,[thres2, boxsize, t2s,2025-t2s],lims,sum_sims.dates,thres2_css,LENSpcolor)


# Projection settings: forecast years run from startf up to (not including)
# endf.  sup is the vertical offset passed to hatch_below when stacking brackets.
startf=2023 #not 2024!
endf=2100
sup=0.05


plt.rcParams['figure.figsize'] = (7,7)


# Quarter-year time axis for the precomputed SSP370 tables.
fdates=np.arange(startf,endf,1/4)
# Precomputed SSP370 results from SSP370.npz and the "ta" variant SSP370ta.npz.
# * incpN is the row whose tabulated 'gmstheights' value is nearest thresN.
# * fcss*/ftpss* are (1000 - x)/1000 of the stored 'cgmst'/'tpgmst' arrays
#   (presumably counts out of 1000 samples -- confirm).
# The with-blocks close the .npz archives; np.load otherwise leaves them open.
with np.load("SSP370.npz") as fdata:
    incp3=np.abs(fdata['gmstheights']-thres3).argmin()
    incp4=np.abs(fdata['gmstheights']-thres4).argmin()
    incp5=np.abs(fdata['gmstheights']-thres5).argmin()
    fcss=(1000-fdata['cgmst'])/1000
    ftpss=(1000-fdata['tpgmst'])/1000

with np.load("SSP370ta.npz") as fdata:
    incp32=np.abs(fdata['gmstheights']-thres3).argmin()
    incp42=np.abs(fdata['gmstheights']-thres4).argmin()
    incp52=np.abs(fdata['gmstheights']-thres5).argmin()
    fcss2=(1000-fdata['cgmst'])/1000
    ftpss2=(1000-fdata['tpgmst'])/1000


# Second figure: 2 x 5 grid of threshold panels.  Top row: climate state;
# bottom row: temperature forecast; columns +0.5 ... +2.5 °C.
fig2, axes=plt.subplots(nrows=2, ncols=5,figsize=(14,8))
fig2.subplots_adjust(wspace=0.45,hspace=0.45)

if(True):
    # SSP projection table.  The with-block closes the file; the bare open()
    # used previously was never closed.
    with open("KF6projectionSSP.csv", "rb") as csvfile:
        data3 = np.genfromtxt(csvfile,dtype=float, delimiter=',')
    #370 is #3 in the list
    (fxhat,fP0)=ekf.ekf_future(startf,ekf.Plastretain,ekf.xlastretain,endf+1,np.log10(data3[:,1+3]),data3[:,6+3],noVolcs=True)
    fxh1s=fxhat[:,0]
    # Standard deviations of the projected state (fstdP) and of the forecast,
    # which additionally includes ekf.Rtvara (fstdS).
    fstdP=np.sqrt(np.abs(fP0))[:,0,0]
    fstdS=np.sqrt(np.abs(fP0+ekf.Rtvara))[:,0,0]
    thres3_fcps=np.zeros((endf-startf,1))
    thres4_fcps=np.zeros((endf-startf,1))
    thres5_fcps=np.zeros((endf-startf,1))
    thres3_fps=np.zeros((endf-startf,1))
    thres4_fps=np.zeros((endf-startf,1))
    thres5_fps=np.zeros((endf-startf,1))
    # Exceedance probabilities 1 - CDF: *_fcps use fstdP, *_fps use fstdS.
    for i in range(endf-startf):
        thres3_fcps[i]=1-ss.norm.cdf(thres3,fxh1s[i],fstdP[i])
        thres4_fcps[i]=1-ss.norm.cdf(thres4,fxh1s[i],fstdP[i])
        thres5_fcps[i]=1-ss.norm.cdf(thres5,fxh1s[i],fstdP[i])
        thres3_fps[i]=1-ss.norm.cdf(thres3,fxh1s[i],fstdS[i])
        thres4_fps[i]=1-ss.norm.cdf(thres4,fxh1s[i],fstdS[i])
        thres5_fps[i]=1-ss.norm.cdf(thres5,fxh1s[i],fstdS[i])
    # Annual dates of the projection; blue curves with brackets at sup/2.
    fdate0=np.arange(startf,endf,1)
    axes[0,4].plot(fdate0,thres5_fcps,'-',color='blue',linewidth=2)
    hatch_below(axes[0,4],fdate0,thres5_fcps[:,0],'blue',sup/2)
    axes[1,4].plot(fdate0,thres5_fps,'-',color='blue',linewidth=2)
    hatch_below(axes[1,4],fdate0,thres5_fps[:,0],'blue',sup/2)
    axes[0,3].plot(fdate0,thres4_fcps,'-',color='blue',linewidth=2)
    hatch_below(axes[0,3],fdate0,thres4_fcps[:,0],'blue',sup/2)
    axes[1,3].plot(fdate0,thres4_fps,'-',color='blue',linewidth=2)
    hatch_below(axes[1,3],fdate0,thres4_fps[:,0],'blue',sup/2)
    axes[0,2].plot(fdate0,thres3_fcps,'-',color='blue',linewidth=2)
    hatch_below(axes[0,2],fdate0,thres3_fcps[:,0],'blue',sup/2)
    # The thres3_fps curve itself is plotted later, together with its legend label.
    hatch_below(axes[1,2],fdate0,thres3_fps[:,0],'blue',sup/2)


# Replace ekf.opt_depth by a trailing "half" running average computed in the
# transformed space 1/(depth + 9.7279).  Each entry i averages the transformed
# values at i-fN..i together with cN-1 copies of ekf.involcavg, which stand in
# for the future years.
# NOTE(review): ekf.involcavg is used untransformed, so this assumes it is
# already expressed in the 1/(depth + 9.7279) space -- confirm in ekf_testv6.
wt_opt_depths = 1/(ekf.opt_depth+9.7279)
N = 30
# Entries outside the loop range (the first fN and the last cN-1) keep ekf.involcavg.
nwt_opt_depths=np.empty(len(ekf.opt_depth)); nwt_opt_depths[:]=ekf.involcavg
cN=int(np.ceil(N/2))
fN=int(np.floor(N/2))
for i in range((fN),(len(nwt_opt_depths)-cN+1)):
    lasta=i+cN;firsta=i-fN  # lasta is not used below
    nwt_opt_depths[i] = (np.sum(wt_opt_depths[(firsta):i+1])+ ekf.involcavg*(cN-1))/N
    #computing half-average - future is assumed to be the average 
# Map back from the transformed space to optical depth.
nopt_depths=(1/nwt_opt_depths-9.7279)
# Mutates ekf's module state; presumably read by the ekf.ekf_run call that follows.
ekf.opt_depth=nopt_depths #reassign newly computed optical depths

# Snapshot the first-run ("uf") estimates before re-running the filter.  The
# copies presumably guard against ekf_run updating ekf's arrays in place.
xh1s=np.copy(ekf.xh1s)
stdP=np.copy(ekf.stdP)
# Second EBM-KF run with the modified optical depths.  Its probabilities are
# plotted below as the "EBM-KF-ta" curves.
(smoothed_xhat,P2s,S2s,xh20s)=ekf.ekf_run(ekf.observ,ekf.n_iters,retPs=2)
xh2s=smoothed_xhat[:,0]
stdP2s=np.sqrt(np.abs(P2s))[:,0,0]
stdS2s=np.sqrt(np.abs(S2s))[:,0,0]

# Exceedance probabilities (1 - normal CDF) for the second run:
#   thresN_cps2 : mean xh2s,  sd stdP2s  (climate state)
#   thresN_ps2  : mean xh20s, sd stdS2s  (forecast; xh20s presumably plays the
#                 role xh0s does for the first run)
thres1_cps2=np.zeros((ekf.n_iters,1))
thres2_cps2=np.zeros((ekf.n_iters,1))
thres3_cps2=np.zeros((ekf.n_iters,1))
thres1_ps2=np.zeros((ekf.n_iters,1))
thres2_ps2=np.zeros((ekf.n_iters,1))
thres3_ps2=np.zeros((ekf.n_iters,1))

for i in range(ekf.n_iters):
    thres1_cps2[i]=1-ss.norm.cdf(thres1,xh2s[i],stdP2s[i])
    thres2_cps2[i]=1-ss.norm.cdf(thres2,xh2s[i],stdP2s[i])
    thres3_cps2[i]=1-ss.norm.cdf(thres3,xh2s[i],stdP2s[i])
    thres1_ps2[i]=1-ss.norm.cdf(thres1,xh20s[i],stdS2s[i])
    thres2_ps2[i]=1-ss.norm.cdf(thres2,xh20s[i],stdS2s[i])
    thres3_ps2[i]=1-ss.norm.cdf(thres3,xh20s[i],stdS2s[i]) 

# Panels e) (climate state, top) and j) (temperature forecast, bottom) for +2.5 °C.
ax3 = axes[0,4]
draw_lines(ax3,ekf.colorstate,2.5)
# x-window uses this threshold's own length t5l (t4l was used before; both are 30).
ax3.set_xlim(t5s,t5s+t5l)
ax3.plot(fdates,fcss2[incp52,:],'-',color=KFtacolor,linewidth=2)
ax3.plot(sum_sims.dates,thres5_css,'-',color=LENSpcolor,linewidth=2, label='CESM2 LENS Model Simulations')
ax3.set_title('e) +2.5°C (SSP370)')
hatch_below(ax3,fdates,fcss2[incp52,:],KFtacolor)
hatch_below(ax3,sum_sims.dates,thres5_css[:,0],LENSpcolor,sup)
minor_xticks = np.arange(t5s,t5s+t5l, 5)
ax3.set_xticks(minor_xticks, minor=True)

ax3 = axes[1,4]
draw_lines(ax3,ekf.coloruncert,2.5)
ax3.set_xlim(t5s,t5s+t5l)
ax3.plot(fdates,ftpss[incp5,:],'-',color=ekf.pcolor,linewidth=2, label='EBM - Kalman')
ax3.plot(sum_sims.dates,thres5_ss,'-',color=LENSpcolor,linewidth=2, label='LENS2 Sims')
ax3.set_title('j) +2.5°C (SSP370)')
hatch_below(ax3,fdates,ftpss[incp5,:],ekf.pcolor)
hatch_below(ax3,sum_sims.dates,thres5_ss[:,0],LENSpcolor,sup)
minor_xticks = np.arange(t5s,t5s+t5l, 5)
ax3.set_xticks(minor_xticks, minor=True)


# Panels d) (climate state, top row) and i) (temperature forecast, bottom row)
# for +2.0 °C.  Curves: precomputed SSP370 tables and LENS fractions, each
# bracketed by hatch_below (LENS brackets raised by sup).
ax3 = axes[0,3]
draw_lines(ax3,ekf.colorstate,2.0)
ax3.set_xlim(t4s,t4s+t4l)
#ax3.plot(fdates,fcss[incp4,:],'-',color=ekf.pcolor,linewidth=2, label='Extended Kalman Filter on Measurements')
ax3.plot(fdates,fcss2[incp42,:],'-',color=KFtacolor,linewidth=2)
ax3.plot(sum_sims.dates,thres4_css,'-',color=LENSpcolor,linewidth=2, label='CESM2 LENS Model Simulations')
ax3.set_title('d) +2.0°C (SSP370)')
#hatch_below(ax3,fdates,fcss[incp4,:],ekf.pcolor,sup/2)
hatch_below(ax3,fdates,fcss2[incp42,:],KFtacolor)
hatch_below(ax3,sum_sims.dates,thres4_css[:,0],LENSpcolor,sup)
# Minor ticks every 5 years across the panel window.
minor_xticks = np.arange(t4s,t4s+t4l, 5)
ax3.set_xticks(minor_xticks, minor=True)

ax3 = axes[1,3]
draw_lines(ax3,ekf.coloruncert,2.0)
ax3.set_xlim(t4s,t4s+t4l)
ax3.plot(fdates,ftpss[incp4,:],'-',color=ekf.pcolor,linewidth=2, label='EBM - Kalman')
#ax3.plot(fdates,ftpss2[incp42,:],'-',color=KFtacolor,linewidth=2)
ax3.plot(sum_sims.dates,thres4_ss,'-',color=LENSpcolor,linewidth=2, label='LENS2 Sims')
ax3.set_title('i) +2.0°C (SSP370)')
hatch_below(ax3,fdates,ftpss[incp4,:],ekf.pcolor)
#hatch_below(ax3,fdates,ftpss2[incp42,:],KFtacolor)
hatch_below(ax3,sum_sims.dates,thres4_ss[:,0],LENSpcolor,sup)
minor_xticks = np.arange(t4s,t4s+t4l, 5)
ax3.set_xticks(minor_xticks, minor=True)

# Panels c) (climate state, top) and h) (temperature forecast, bottom) for
# +1.5 °C.  These middle-column panels also carry the row banners and the
# shared legend.
ax3 = axes[0,2]
draw_lines(ax3,ekf.colorstate,1.5)
ax3.set_xlim(t3s,t3s+t3l)
ax3.plot(ekf.dates,thres3_cps2,'-',color=KFtacolor,linewidth=2, label="EBM-KF-ta")
ax3.plot(ekf.dates,thres3_cps,'-',color=ekf.colorekf,linewidth=2, label="EBM-KF-uf")
ax3.plot(fdates,fcss2[incp32,:],'-',color=KFtacolor,linewidth=2)
ax3.plot(sum_sims.dates,thres3_css,'-',color=LENSpcolor,linewidth=2, label='CESM2 LENS Model Simulations')
ax3.set_title('c) +1.5°C (SSP370)')
ax3.text((t3s+t3s+t3l)/2,1.15,'                                                                  Climate State                                                                  ',
         size=18,horizontalalignment='center',backgroundcolor=ekf.colorstate)
hatch_below(ax3,fdates,fcss2[incp32,:],KFtacolor)
hatch_below(ax3,sum_sims.dates,thres3_css[:,0],LENSpcolor,sup)
minor_xticks = np.arange(t3s,t3s+t3l, 5)
ax3.set_xticks(minor_xticks, minor=True)
# Mark and annotate the last value of the EBM-KF-ta curve.
# NOTE(review): the number printed is thres3_cps (EBM-KF-uf), while the
# marker position and colour belong to thres3_cps2 (EBM-KF-ta).  Confirm
# which value is intended.
ax3.text(ekf.dates[-1],thres3_cps2[-1],'$\\ast$',color=KFtacolor,ha='center',va='center',size=20)
ax3.text(ekf.dates[-1]-3,thres3_cps2[-1]+.045,format(thres3_cps[-1][0], ".5f"),color=KFtacolor,ha='center',va='center',size=10,rotation=15)

ax3 = axes[1,2]
# Drawn at y = -0.5, outside the [0, 1] range set by draw_lines; it only
# supplies the running-mean legend entry.  Raw string: \overline is LaTeX,
# not a Python escape.
ax3.plot(fineyears[x1crossavgyr],-0.5,'.',color='yellow', markeredgecolor="mediumseagreen",label=r"Running Mean $\overline{_{30}Y_n}$",zorder=3,alpha=1, markersize=18, markeredgewidth=3)
draw_lines(ax3,ekf.coloruncert,1.5) #previously had dashed line behind here
ax3.set_xlim(t3s,t3s+t3l)
ax3.plot(ekf.dates,thres3_ps2-30,'-',color=KFtacolor,linewidth=2, label="EBM-KF-ta") #plot offscreen to get color in legend
ax3.plot(ekf.dates,thres3_ps,'-',color=ekf.colorekf,linewidth=2, label="EBM-KF-uf")
ax3.plot(fdates,ftpss[incp3,:],'-',color=ekf.pcolor,linewidth=2, label='EBM-KF Volc')
ax3.plot(sum_sims.dates,thres3_ss,'-',color=LENSpcolor,linewidth=2, label='LENS2')
ax3.set_title('h) +1.5°C (SSP370)')
ax3.text((t3s+t3s+t3l)/2,1.15,'                                                             Temperature Forecast                                                             ',
         size=18,horizontalalignment='center',backgroundcolor=ekf.coloruncert)
hatch_below(ax3,fdates,ftpss[incp3,:],ekf.pcolor)
hatch_below(ax3,sum_sims.dates,thres3_ss[:,0],LENSpcolor,sup)
minor_xticks = np.arange(t3s,t3s+t3l, 5)
ax3.set_xticks(minor_xticks, minor=True)
axes[1,2].plot(fdate0,thres3_fps,'-',color='blue',linewidth=2,label="EBM-KF Unif")
ax3.text(ekf.dates[-1],thres3_ps[-1],'$\\ast$',color=ekf.colorekf,ha='center',va='center',size=20)
ax3.text(ekf.dates[-1]-3,thres3_ps[-1]+0.05,format(thres3_ps[-1][0], ".1%"),color=ekf.colorekf,ha='center',va='center',size=15,rotation=70,backgroundcolor='white',zorder=1)
# A single legend for the whole figure, placed below this panel.
ax3.legend(loc='upper center', prop={'size': 14}, bbox_to_anchor=(0.5, -0.13), ncol=6)

######
# Panels b) (climate state, top row) and g) (temperature forecast, bottom row)
# for the +1.0 °C threshold.
ax3 = axes[0,1]
draw_lines(ax3,ekf.colorstate,1.0)
ax3.set_xlim(t2s,t2s+t2l)
#ax3.plot(ekf.dates,thres2_cps,'-',color=ekf.colorekf,linewidth=2, label='Extended Kalman Filter on Measurements')
ax3.plot(sum_sims.dates,thres2_css,'-',color=LENSpcolor,linewidth=2, label='CESM2 LENS Model Simulations')
ax3.set_title('b) +1.0°C Threshold')
#hatch_below(ax3,ekf.dates,thres2_cps[:,0],ekf.colorekf)
hatch_below(ax3,sum_sims.dates,thres2_css[:,0],LENSpcolor,sup/2)
ax3.plot(ekf.dates,thres2_cps2,'-',color=KFtacolor,linewidth=2)
hatch_below(ax3,ekf.dates,thres2_cps2[:,0],KFtacolor)
# Minor ticks every 5 years starting at 1995 (the multiple of 5 just below t2s).
minor_xticks = np.arange(1995,t2s+t2l, 5)
ax3.set_xticks(minor_xticks, minor=True)
# Range of first-crossing times of the two projected running means, drawn as
# yellow-on-green dots at p = 0.5.  NOTE(review): np.repeat needs
# x2crossavgyr_min >= x2crossavgyr_max - 1, i.e. the "max" projection crossing
# no later than the "min" one -- confirm.
ax3.plot(fineyears[x2crossavgyr_max:x2crossavgyr_min+1],np.repeat(0.5,x2crossavgyr_min-x2crossavgyr_max+1),'.',color='mediumseagreen',markersize=22 )
ax3.plot(fineyears[x2crossavgyr_max:x2crossavgyr_min+1],np.repeat(0.5,x2crossavgyr_min-x2crossavgyr_max+1),'.',color='yellow',zorder=3,alpha=1, markersize=10)

ax3 = axes[1,1]
draw_lines(ax3,ekf.coloruncert,1.0)
ax3.set_xlim(t2s,t2s+t2l)
ax3.plot(ekf.dates,thres2_ps,'-',color=ekf.colorekf,linewidth=2, label='EBM - Kalman')
ax3.plot(sum_sims.dates,thres2_ss,'-',color=LENSpcolor,linewidth=2, label='LENS2 Sims')
ax3.set_title('g) +1.0°C Threshold')
hatch_below(ax3,ekf.dates,thres2_ps[:,0],ekf.colorekf)
hatch_below(ax3,sum_sims.dates,thres2_ss[:,0],LENSpcolor,sup/2)
#ax3.text(fineyears[x2crossavgyr],0,'$\\ast$',color='k',ha='center',size=14)
#ax3.plot(ekf.dates,thres2_ps2,'-',color=KFtacolor,linewidth=2)
#hatch_below(ax3,ekf.dates,thres2_ps2[:,0],KFtacolor,sup/2)
minor_xticks = np.arange(1995,t2s+t2l, 5)
ax3.set_xticks(minor_xticks, minor=True)
ax3.plot(fineyears[x2crossavgyr_max:x2crossavgyr_min+1],np.repeat(0.5,x2crossavgyr_min-x2crossavgyr_max+1),'.',color='mediumseagreen',markersize=22 )
ax3.plot(fineyears[x2crossavgyr_max:x2crossavgyr_min+1],np.repeat(0.5,x2crossavgyr_min-x2crossavgyr_max+1),'.',color='yellow',zorder=3,alpha=1, markersize=10)

# Panels a) (climate state, top) and f) (temperature forecast, bottom) for +0.5 °C.
ax3 = axes[0,0]
draw_lines(ax3,ekf.colorstate,0.5)
ax3.set_xlim(t1s,t1s+t1l)
ax3.plot(sum_sims.dates,thres1_css,'-',color=LENSpcolor,linewidth=2, label='CESM2 LENS Model Simulations')
ax3.set_title('a) +0.5°C Threshold')
hatch_below(ax3,sum_sims.dates,thres1_css[:,0],LENSpcolor,sup/2)
minor_xticks = np.arange(t1s,t1s+t1l, 5)
ax3.set_xticks(minor_xticks, minor=True)
ax3.plot(ekf.dates,thres1_cps2,'-',color=KFtacolor,linewidth=2)
hatch_below(ax3,ekf.dates,thres1_cps2[:,0],KFtacolor)
# First year the running mean exceeds thres1, marked at p = 0.5.
ax3.plot(fineyears[x1crossavgyr],0.5,'.',color='yellow', markeredgecolor="mediumseagreen",zorder=3,alpha=1, markersize=18, markeredgewidth=3)

ax3 = axes[1,0]
# Pass the colour itself; this previously passed ekf.coloruncert.count, a
# bound method, by mistake.
draw_lines(ax3,ekf.coloruncert,0.5)
ax3.set_xlim(t1s,t1s+t1l)
ax3.plot(ekf.dates,thres1_ps,'-',color=ekf.colorekf,linewidth=2, label='EBM-KF')
ax3.plot(sum_sims.dates,thres1_ss,'-',color=LENSpcolor,linewidth=2, label='LENS2')
ax3.set_title('f) +0.5°C Threshold')
hatch_below(ax3,ekf.dates,thres1_ps[:,0],ekf.colorekf)
hatch_below(ax3,sum_sims.dates,thres1_ss[:,0],LENSpcolor,sup/2)
minor_xticks = np.arange(t1s,t1s+t1l, 5)
ax3.set_xticks(minor_xticks, minor=True)
# 30-year running-mean crossing point (same marker as in panel a).
ax3.plot(fineyears[x1crossavgyr],0.5,'.',color='yellow', markeredgecolor="mediumseagreen",zorder=3,alpha=1, markersize=18, markeredgewidth=3)





# Top-left panel (ax1): EBM-KF climate-state estimate with its ±2·stdP band
# and an inset of the +1.0 °C exceedance probability.
plt.sca(ax1)
ax1.set_title("Climate State Probabilities \n of Threshold Crossings with EBM-KF-uf",fontsize=bigfontb)


# Raw strings keep the LaTeX backslashes (\^, \pm, \sqrt, \hat) without
# invalid-escape warnings; the label text is unchanged.
ax1.plot(ekf.dates,xh1s,'-',label=r'$\^{T }_{t}$', color=ekf.colorekf,linewidth=0.5)
ax1.fill_between(ekf.dates, xh1s-2*stdP, xh1s+2*stdP,label=r"$\pm 2\sqrt{\hat{p}^T_t}$", color=ekf.colorstate,alpha=0.3,zorder=0)

ekf.plot_boilerplate(ax1)
ax1.set_xlim([lims[2],lims[3]])
ax1.set_ylim([lims[0],lims[1]])
ax1.set_yticks(np.arange(286.2,288.3,0.4))

draw_threshold_lines(ax1)
# Off-screen dummy line that supplies the legend entry for the inset curve.
ax1.plot([lims[3]+100,lims[3]+100], [ekf.pindavg, ekf.pindavg], '-',color=ekf.colorekf,linewidth=2,label=r'$\mathbb{P}(\hat{T_t} ≥ 1\degree C )$')
# Query ax1 directly, as the other panels do, instead of relying on plt.gca().
handles, labels = ax1.get_legend_handles_labels()
order = [0,2,1]
ax1.legend([handles[idx] for idx in order],[labels[idx] for idx in order],prop={'size': bigfonts-1},loc="lower right",framealpha=1)

# Inset: state probability of exceeding thres2 (+1.0 °C) from t2s to 2025.
axthres2=cdf_subplot(ax1,[thres2, boxsize, t2s,2025-t2s],lims,ekf.dates,thres2_cps,ekf.colorekf)

# Letter the four axes of the first figure.  Only as many labels as axes are
# used; zip stops at the shorter list.
ax_dict=[ax1,ax12,ax4,ax2]
axlabels=['a)','b)','c)','d)','e)']
for ax, label in zip(ax_dict, axlabels):
    # Offset of 40 pt left and 5 pt up from the axes' top-left corner
    # (ScaledTranslation works in inches: 1 pt = 1/72 in).
    trans = mtransforms.ScaledTranslation(-40/72, 5/72, fig.dpi_scale_trans)
    ax.text(0.0, 1.0, label, transform=ax.transAxes + trans, fontsize='large', va='bottom')

# Export both figures as PDF, then show them interactively.
fig.savefig("cdf1r.pdf",format="pdf")
fig2.savefig("cdf4r.pdf",format="pdf")
##fig2.savefig("cdf4.png", dpi=400,format="png")

plt.show()
